{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6nK8DDl0ReVb"
      },
      "source": [
        "Copyright 2022 Google LLC.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qLtViTDjRnwz"
      },
      "outputs": [],
      "source": [
        "#@title Example Dataset Preparation\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bcra0xwZRqj1"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5G03t3baW46i"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6WH0QcJHW_Js"
      },
      "outputs": [],
      "source": [
        "import cv2\n",
        "import pandas\n",
        "import PIL\n",
        "import pycocotools\n",
        "import scipy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lUJmKisvs9cE"
      },
      "outputs": [],
      "source": [
        "# Ensure we are working from google_research/ directory (not google_research/factors_of_influence):\n",
        "# cd ../"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0xmO5PPmXK2_"
      },
      "outputs": [],
      "source": [
        "from factors_of_influence.fids import fids_tfds_builders as fids"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Bj7P22NIEDs"
      },
      "source": [
        "# Prepare a single dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sdfguAlCXlob"
      },
      "outputs": [],
      "source": [
        "# For debugging, consider passing data_dir='/tmp/fids/tfds' to tfds.builder, which\n",
        "# saves the tfds dataset in that directory instead of the default tfds data directory.\n",
        "ds_build = tfds.builder('fids_kitti_segmentation')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OWkQ-HHuX7KX"
      },
      "outputs": [],
      "source": [
        "ds_build.download_and_prepare()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EqLHVfnqYJM0"
      },
      "outputs": [],
      "source": [
        "ds_info = ds_build.info\n",
        "ds = ds_build.as_dataset(split='train')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PB8zwouWYQy2"
      },
      "outputs": [],
      "source": [
        "# Summarize the dataset metadata: name, config, and the length of each feature entry.\n",
        "metadata = ds_info.metadata\n",
        "print(metadata['name'], metadata['config'])\n",
        "\n",
        "features = metadata['features']\n",
        "for feature_name in features:\n",
        "  print(feature_name, len(features[feature_name]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I1MgVuP6YURs"
      },
      "outputs": [],
      "source": [
        "for ex in ds.take(1).as_numpy_iterator():\n",
        "    print(ex.keys())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q7j2CbfMYW39"
      },
      "outputs": [],
      "source": [
        "# Show the example's image next to its segmentation, one titled panel each.\n",
        "fig, (ax_image, ax_segmentation) = plt.subplots(1, 2, figsize=(20, 10))\n",
        "ax_image.imshow(ex['image'])\n",
        "ax_image.set_title('image')\n",
        "ax_segmentation.imshow(ex['segmentation'])\n",
        "# The trailing semicolon keeps the cell from echoing the last call's repr.\n",
        "ax_segmentation.set_title('segmentation');"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4TE6R4quHa_K"
      },
      "source": [
        "## Subsequent calls\n",
        "Subsequent calls to the dataset can use the `tfds.load` interface as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tjQ34qMvHcso"
      },
      "outputs": [],
      "source": [
        "ds, ds_info = tfds.load('fids_kitti_segmentation', split='train', with_info=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9RotkXZiIInw"
      },
      "source": [
        "# Prepare all datasets and all configs\n",
        "Some datasets have multiple configs; here we show how all of them can be generated."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qBLouN9aIJ2G"
      },
      "outputs": [],
      "source": [
        "def _yield_all_datasets_and_configs(fids_datasets: list[str] = None):\n",
        "  \"\"\"Generates every (dataset name, config name) pair for the given datasets.\n",
        "\n",
        "  Args:\n",
        "    fids_datasets: Names of the TFDS datasets to enumerate. If None, all\n",
        "      builders listed by `tfds.list_builders()` whose name starts with\n",
        "      'fids_' are used.\n",
        "\n",
        "  Yields:\n",
        "    A `(dataset_name, config_name)` tuple for each entry in the\n",
        "    `BUILDER_CONFIGS` of each dataset's builder.\n",
        "  \"\"\"\n",
        "  if fids_datasets is None:\n",
        "    fids_datasets = [\n",
        "        name for name in tfds.list_builders() if name.startswith('fids_')\n",
        "    ]\n",
        "\n",
        "  for dataset_name in fids_datasets:\n",
        "    builder = tfds.builder(dataset_name)\n",
        "    # Collect the config names up front, then emit one pair per config.\n",
        "    config_names = [config.name for config in builder.BUILDER_CONFIGS]\n",
        "    yield from ((dataset_name, config_name) for config_name in config_names)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QGxSVVK4tK7d"
      },
      "outputs": [],
      "source": [
        "for ds_name, config_name in _yield_all_datasets_and_configs():\n",
        "  print(f'{ds_name:25s} | {config_name:20s}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y1GO8Z7yIRuv"
      },
      "outputs": [],
      "source": [
        "# Warning: This starts to convert all datasets to TFDS sequentially!\n",
        "for ds_name, config_name in _yield_all_datasets_and_configs():\n",
        "  print(f'{ds_name:25s} | {config_name:20s}')\n",
        "  # Select the specific config through the 'dataset/config' builder name,\n",
        "  # built from this loop's own variables.\n",
        "  ds_build = tfds.builder(f'{ds_name}/{config_name}')\n",
        "  ds_build.download_and_prepare()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "example_dataset_preparation",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "19rvIFdBjs9Rw3bHi1T1gT-txZq7QCMVN",
          "timestamp": 1657643120025
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
